// Copyright © 2019, Oracle and/or its affiliates.
package oci

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/armon/go-metrics"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/sdk/physical"
	"github.com/oracle/oci-go-sdk/objectstorage"
)

// The lock implementation below prioritizes guaranteeing that there are never two primaries at the same time
// over the availability of the primary instance.

// Compile-time checks that *Backend satisfies physical.HABackend and *Lock
// satisfies physical.Lock.
var (
	_ physical.HABackend = (*Backend)(nil)
	_ physical.Lock      = (*Lock)(nil)
)

const (
	// LockRenewInterval is the time to wait between lock renewals (rewrites of the lock record by the holder).
	LockRenewInterval = 3 * time.Second

	// LockRetryInterval is the interval between lock acquisition attempts while waiting to become the holder.
	LockRetryInterval = 5 * time.Second

	// LockWatchRetryInterval is the interval between the holder's checks of the lock record, including
	// checks that follow a failed read.
	LockWatchRetryInterval = 2 * time.Second

	// LockTTL is the default lock TTL: once the holder's cached lock record is older than this, the holder
	// no longer treats itself as owning the lock.
	LockTTL = 15 * time.Second

	// LockWatchRetryMax bounds failed watch reads before leadership is given up. Note that the watcher
	// gives up after LockWatchRetryMax-1 consecutive failed reads.
	LockWatchRetryMax = 4

	// LockCacheMinAcceptableAge is the minimum time a secondary instance must observe an unchanged lock
	// record (same ETag) before it considers it safe to try to acquire the lock.
	LockCacheMinAcceptableAge = 45 * time.Second

	// LockWriteRetriesOnFailures is the total number of PUT attempts made when writing the lock record;
	// only 5xx responses trigger another attempt.
	LockWriteRetriesOnFailures = 4

	// ObjectStorageCallsReadTimeout bounds each GetObject call made when reading the lock record.
	ObjectStorageCallsReadTimeout = 3 * time.Second

	// ObjectStorageCallsWriteTimeout bounds each object storage write call made by the lock
	// (e.g. each PutObject attempt when writing the lock record).
	ObjectStorageCallsWriteTimeout = 3 * time.Second
)

// LockCache is this instance's last known view of the lock object: the
// object's ETag (empty when no object existed), the time this entry was
// stored, and the decoded lock record (nil when no object existed).
// On the secondary path, a new entry is only stored when the observed ETag
// changes, so lastUpdate then marks when the record was last seen to change.
type LockCache struct {
	// ETag values are unique identifiers generated by the OCI service and changed every time the object is modified.
	etag       string
	lastUpdate time.Time
	lockRecord *LockRecord
}

// Lock is an HA lock backed by a single object in the backend's lock bucket.
// It implements physical.Lock.
type Lock struct {
	// backend is the underlying physical backend.
	backend *Backend

	// key is the object name of the lock in the lock bucket; value is the
	// value written into the lock record while this instance holds the lock.
	key, value string

	// held is a boolean indicating if the lock is currently held. It is
	// read and written under internalLock.
	held bool

	// identity is a per-Lock UUID (generated in LockWith) written into the
	// lock record so this instance can recognize its own record.
	identity string

	// internalLock serializes Lock and Unlock.
	internalLock sync.Mutex

	// stopCh is the channel that stops all operations. It may be closed in the
	// event of a leader loss or graceful shutdown. stopped is a boolean
	// indicating if we are stopped - it exists to prevent double closing the
	// channel. stopLock guards stopped and the replacement/closing of stopCh.
	stopCh   chan struct{}
	stopped  bool
	stopLock sync.Mutex

	// lockRecordCache holds the most recently stored *LockCache.
	lockRecordCache atomic.Value

	// Allow modifying the Lock durations for ease of unit testing.
	renewInterval      time.Duration
	retryInterval      time.Duration
	ttl                time.Duration
	watchRetryInterval time.Duration
	watchRetryMax      int
}

// LockRecord is the content of the lock object. It is stored as JSON using
// encoding/json's default field names ("Key", "Value", "Identity"); Identity
// names the instance that wrote it.
type LockRecord struct {
	Key      string
	Value    string
	Identity string
}

// Metric keys emitted by the HA lock. metricLeaderValue is set to 1 whenever
// this instance successfully writes the lock record.
var (
	metricLockUnlock  = []string{"oci", "lock", "unlock"}
	metricLockLock    = []string{"oci", "lock", "lock"}
	metricLockValue   = []string{"oci", "lock", "Value"}
	metricLeaderValue = []string{"oci", "leader", "Value"}
)

// HAEnabled reports whether high-availability mode is enabled on this
// backend, as recorded in its haEnabled field.
func (b *Backend) HAEnabled() bool {
	enabled := b.haEnabled
	return enabled
}

// LockWith returns a new, not-yet-acquired lock on key that will record value
// while held. Every call generates a fresh identity so that records written by
// different Lock instances can be told apart. Durations start at the package
// defaults.
func (b *Backend) LockWith(key, value string) (physical.Lock, error) {
	id, err := uuid.GenerateUUID()
	if err != nil {
		return nil, fmt.Errorf("Lock with: %w", err)
	}

	// No renew/watch goroutines exist yet, so the lock starts out stopped.
	l := &Lock{backend: b, key: key, value: value, identity: id, stopped: true}
	l.renewInterval = LockRenewInterval
	l.retryInterval = LockRetryInterval
	l.ttl = LockTTL
	l.watchRetryInterval = LockWatchRetryInterval
	l.watchRetryMax = LockWatchRetryMax
	return l, nil
}

// Lock blocks until the lock is acquired, acquisition fails, or stopCh is
// closed. On success it starts the renew and watch goroutines and returns a
// channel that is closed when leadership is lost or Unlock is called. If stopCh
// closes first, it returns (nil, nil). Calling Lock while the lock is already
// held returns an error.
func (l *Lock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	l.backend.logger.Debug("Lock() called")
	defer metrics.MeasureSince(metricLockLock, time.Now().UTC())

	l.internalLock.Lock()
	defer l.internalLock.Unlock()

	if l.held {
		return nil, errors.New("lock already held")
	}

	ok, err := l.attemptLock(stopCh)
	switch {
	case err != nil:
		return nil, fmt.Errorf("lock: %w", err)
	case !ok:
		return nil, nil
	}

	l.held = true

	// A fresh stop channel for this leadership term; the renew and watch
	// goroutines both stop when it is closed.
	leaderLost := make(chan struct{})
	l.stopLock.Lock()
	l.stopCh = leaderLost
	l.stopped = false
	l.stopLock.Unlock()

	go l.renewLock()
	go l.watchLock()

	return leaderLost, nil
}

// attemptLock repeatedly tries to write the lock record, once every
// retryInterval (the first try happens after the first interval elapses). It
// returns (true, nil) once a write succeeds, (false, nil) if stopCh is closed
// first, and (false, err) as soon as a write attempt returns an error.
func (l *Lock) attemptLock(stopCh <-chan struct{}) (bool, error) {
	l.backend.logger.Debug("AttemptLock() called")
	retry := time.NewTicker(l.retryInterval)
	defer retry.Stop()

	for {
		select {
		case <-stopCh:
			return false, nil
		case <-retry.C:
		}

		acquired, err := l.writeLock()
		if err != nil {
			return false, fmt.Errorf("attempt lock: %w", err)
		}
		if acquired {
			return true, nil
		}
	}
}

// renewLock rewrites the lock record every renewInterval until l.stopCh is
// closed. A failed renewal is not fatal here: watchLock gives up leadership
// once the cached record is older than l.ttl. For that reason a renewal error
// is logged for visibility rather than propagated.
func (l *Lock) renewLock() {
	l.backend.logger.Debug("RenewLock() called")
	ticker := time.NewTicker(l.renewInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if _, err := l.writeLock(); err != nil {
				l.backend.logger.Warn("RenewLock: failed to renew lock", "err", err)
			}
		case <-l.stopCh:
			return
		}
	}
}

// loadLockRecordCache returns the *LockCache most recently stored in
// l.lockRecordCache, or nil if nothing has been stored yet.
func loadLockRecordCache(l *Lock) *LockCache {
	if cached := l.lockRecordCache.Load(); cached != nil {
		return cached.(*LockCache)
	}
	return nil
}

// watchLock runs while this instance holds the lock and closes l.stopCh, giving
// up leadership, as soon as ownership can no longer be confirmed. Every
// watchRetryInterval it first checks the local cache, which must exist, carry
// this instance's identity, and be no older than l.ttl. The cache is refreshed
// by each successful writeLock, which renewLock calls. It then reads the lock
// object from object storage, which must exist and carry this instance's
// identity. A failed read is retried on the next tick. After watchRetryMax-1
// consecutive failed reads, or on any failed ownership check, the watcher stops
// and closes the channel (unless it is already closed). This favors never
// having two primaries over keeping the current one.
func (l *Lock) watchLock() {
	l.backend.logger.Debug("WatchLock() called")
	retries := 0
	ticker := time.NewTicker(l.watchRetryInterval)
	defer ticker.Stop()

OUTER:
	for {
		// Check if the channel is already closed
		select {
		case <-l.stopCh:
			l.backend.logger.Debug("WatchLock():Stop lock signaled/closed.")
			break OUTER
		default:
		}

		// Check if we've exceeded retries. retries counts consecutive failed
		// reads, so this gives up after watchRetryMax-1 of them.
		if retries >= l.watchRetryMax-1 {
			l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. Giving up the lease after max retries")
			break OUTER
		}

		// Wait for the timer
		select {
		case <-ticker.C:
		case <-l.stopCh:
			break OUTER
		}

		// Local check: renewals must be keeping our own cached record fresh.
		lockRecordCache := loadLockRecordCache(l)
		if (lockRecordCache == nil) ||
			(lockRecordCache.lockRecord == nil) ||
			(lockRecordCache.lockRecord.Identity != l.identity) ||
			(time.Now().Sub(lockRecordCache.lastUpdate) > l.ttl) {
			l.backend.logger.Debug("WatchLock: Lock record cache is nil, stale or does not belong to self.")
			break OUTER
		}

		// Remote check; get bounds the call with ObjectStorageCallsReadTimeout.
		lockRecord, _, err := l.get(context.Background())
		if err != nil {
			retries++
			l.backend.logger.Debug("WatchLock: Failed to get lock data from object storage. Retrying..")
			metrics.SetGauge(metricHaWatchLockRetriable, 1)
			continue
		}

		// NOTE: despite the log text below, this checks the record read from
		// object storage, not the local cache.
		if (lockRecord == nil) || (lockRecord.Identity != l.identity) {
			l.backend.logger.Debug("WatchLock: Lock record cache is nil or does not belong to self.")
			break OUTER
		}

		// reset retries counter on success
		retries = 0
		l.backend.logger.Debug("WatchLock() successful")
		metrics.SetGauge(metricHaWatchLockRetriable, 0)
	}

	// Close the stop channel exactly once; Unlock may have closed it already.
	l.stopLock.Lock()
	defer l.stopLock.Unlock()
	if !l.stopped {
		l.stopped = true
		l.backend.logger.Debug("Closing the stop channel to give up leadership.")
		close(l.stopCh)
	}
}

// Unlock releases the lock if this instance holds it; otherwise it is a no-op.
// It first closes the stop channel (stopping the renew and watch goroutines)
// and marks the lock as not held. It then deletes the lock object, but only if
// the stored record still carries this instance's identity. The delete is
// conditioned (If-Match) on the ETag just read, so a record rewritten in the
// meantime is not removed.
func (l *Lock) Unlock() error {
	l.backend.logger.Debug("Unlock() called")
	defer metrics.MeasureSince(metricLockUnlock, time.Now().UTC())

	l.internalLock.Lock()
	defer l.internalLock.Unlock()
	if !l.held {
		return nil
	}

	// Stop any existing locking or renewal attempts
	l.stopLock.Lock()
	if !l.stopped {
		l.stopped = true
		close(l.stopCh)
	}
	l.stopLock.Unlock()

	// We are no longer holding the lock
	l.held = false

	currentLockRecord, etag, err := l.get(context.Background())
	if err != nil {
		return fmt.Errorf("error reading lock record: %w", err)
	}
	if currentLockRecord == nil || currentLockRecord.Identity != l.identity {
		// No record, or another instance's record: nothing of ours to delete.
		return nil
	}

	defer metrics.MeasureSince(metricDeleteHa, time.Now())
	opcClientRequestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Debug("Unlock: error generating UUID")
		return fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("Unlock", "opc-client-request-id", opcClientRequestID)
	request := objectstorage.DeleteObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		IfMatch:            &etag,
		OpcClientRequestId: &opcClientRequestID,
	}

	// Bound the delete like the other object storage calls. Otherwise a hung
	// request would block Unlock, while it holds internalLock, indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), ObjectStorageCallsWriteTimeout)
	defer cancel()

	response, err := l.backend.client.DeleteObject(ctx, request)
	l.backend.logRequest("deleteHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
	if response.RawResponse != nil {
		response.RawResponse.Body.Close()
	}

	if err != nil {
		metrics.IncrCounter(metricDeleteFailed, 1)
		return fmt.Errorf("delete lock: %w", err)
	}
	return nil
}

// Value reads the current lock record from object storage. It reports whether
// a record exists and, if so, the value stored in it. A missing lock object is
// not an error: it yields (false, "", nil).
func (l *Lock) Value() (bool, string, error) {
	l.backend.logger.Debug("Value() called")
	defer metrics.MeasureSince(metricLockValue, time.Now().UTC())

	record, _, err := l.get(context.Background())
	switch {
	case err != nil:
		return false, "", err
	case record == nil:
		return false, "", nil
	default:
		return true, record.Value, nil
	}
}

// get reads the lock object from the lock bucket and decodes it as a
// LockRecord, returning the record together with the object's ETag.
//
// A missing object (HTTP 404) is not an error: it yields (nil, "", nil),
// meaning no lock record currently exists. The GetObject call is bounded by
// ObjectStorageCallsReadTimeout in addition to any deadline already on ctx.
func (l *Lock) get(ctx context.Context) (*LockRecord, string, error) {
	l.backend.logger.Debug("Called getLockRecord()")

	defer metrics.MeasureSince(metricGetHa, time.Now())
	opcClientRequestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Error("getHa: error generating UUID")
		return nil, "", fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("getHa", "opc-client-request-id", opcClientRequestID)

	request := objectstorage.GetObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		OpcClientRequestId: &opcClientRequestID,
	}

	ctx, cancel := context.WithTimeout(ctx, ObjectStorageCallsReadTimeout)
	defer cancel()

	response, err := l.backend.client.GetObject(ctx, request)
	l.backend.logRequest("getHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)

	if err != nil {
		// No object means no lock record; that is a normal state, not a failure.
		if response.RawResponse != nil && response.RawResponse.StatusCode == http.StatusNotFound {
			return nil, "", nil
		}

		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error calling GET", "err", err)
		return nil, "", fmt.Errorf("failed to read Value for %q: %w", l.key, err)
	}

	defer response.RawResponse.Body.Close()

	// Callers key conditional writes and deletes on this ETag. Fail cleanly
	// instead of dereferencing a nil pointer in a background goroutine.
	if response.ETag == nil {
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("GET response has no ETag")
		return nil, "", fmt.Errorf("failed to read Value for %q: response has no ETag", l.key)
	}

	body, err := ioutil.ReadAll(response.Content)
	if err != nil {
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error reading content", "err", err)
		return nil, "", fmt.Errorf("failed to decode Value into bytes: %w", err)
	}

	var lockRecord LockRecord
	if err := json.Unmarshal(body, &lockRecord); err != nil {
		metrics.IncrCounter(metricGetFailed, 1)
		l.backend.logger.Error("Error un-marshalling content", "err", err)
		return nil, "", fmt.Errorf("failed to read Value for %q: %w", l.key, err)
	}

	return &lockRecord, *response.ETag, nil
}

// writeLock tries to write this instance's lock record to object storage and
// reports whether the write succeeded, i.e. whether this instance now holds
// or has just renewed the lock.
//
// If the cached record is missing, belongs to another identity, or is older
// than l.ttl, this instance acts as a secondary. It re-reads the lock object
// and only proceeds with the write when no record exists, or when the observed
// record (tracked by its ETag) has been unchanged for at least
// LockCacheMinAcceptableAge. The PUT is conditioned on the last known ETag
// (If-Match), or on the object not existing (If-None-Match: *) when no ETag is
// known, so a concurrent writer's update is not silently overwritten.
//
// On success the cache is replaced with the new record and its ETag.
func (l *Lock) writeLock() (bool, error) {
	l.backend.logger.Debug("WriteLock() called")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Abort in-flight object storage calls as soon as the lock is stopped
	// (leadership lost or Unlock called).
	go func() {
		select {
		case <-l.stopCh:
			cancel()
		case <-ctx.Done():
		}
	}()

	lockRecordCache := loadLockRecordCache(l)
	if lockRecordCache == nil || lockRecordCache.lockRecord == nil ||
		lockRecordCache.lockRecord.Identity != l.identity ||
		time.Since(lockRecordCache.lastUpdate) > l.ttl {
		// Secondary path: we are not a fresh holder of the lock.
		currentLockRecord, currentEtag, err := l.get(ctx)
		if err != nil {
			return false, fmt.Errorf("error reading lock record: %w", err)
		}

		// Restart the age clock only when the stored object actually changed.
		if lockRecordCache == nil || lockRecordCache.etag != currentEtag {
			l.lockRecordCache.Store(&LockCache{
				etag:       currentEtag,
				lastUpdate: time.Now().UTC(),
				lockRecord: currentLockRecord,
			})
			lockRecordCache = loadLockRecordCache(l)
		}

		// A nil record means there is no leader, so try to acquire right away.
		// Otherwise wait until the existing record has been unchanged long enough.
		if currentLockRecord != nil && time.Since(lockRecordCache.lastUpdate) < LockCacheMinAcceptableAge {
			return false, nil
		}
	}

	newLockRecord := &LockRecord{
		Key:      l.key,
		Value:    l.value,
		Identity: l.identity,
	}
	newLockRecordJSON, err := json.Marshal(newLockRecord)
	if err != nil {
		return false, fmt.Errorf("error marshaling lock record: %w", err)
	}

	defer metrics.MeasureSince(metricPutHa, time.Now())

	opcClientRequestID, err := uuid.GenerateUUID()
	if err != nil {
		l.backend.logger.Error("putHa: error generating UUID")
		return false, fmt.Errorf("failed to generate UUID: %w", err)
	}
	l.backend.logger.Debug("putHa", "opc-client-request-id", opcClientRequestID)

	size := int64(len(newLockRecordJSON))
	putRequest := objectstorage.PutObjectRequest{
		NamespaceName:      &l.backend.namespaceName,
		BucketName:         &l.backend.lockBucketName,
		ObjectName:         &l.key,
		ContentLength:      &size,
		OpcClientRequestId: &opcClientRequestID,
	}
	if lockRecordCache.etag == "" {
		noneMatch := "*"
		putRequest.IfNoneMatch = &noneMatch
	} else {
		putRequest.IfMatch = &lockRecordCache.etag
	}

	newEtag, err := l.putLockRecordWithRetries(ctx, putRequest, newLockRecordJSON)
	if err != nil {
		return false, fmt.Errorf("write lock: %w", err)
	}

	l.backend.logger.Debug("Lock written", "record", string(newLockRecordJSON))

	l.lockRecordCache.Store(&LockCache{
		etag:       newEtag,
		lastUpdate: time.Now().UTC(),
		lockRecord: newLockRecord,
	})

	metrics.SetGauge(metricLeaderValue, 1)
	return true, nil
}

// putLockRecordWithRetries sends request with body as its payload. It makes up
// to LockWriteRetriesOnFailures attempts in total, retrying only 5xx
// responses, and returns the ETag of the written object. On failure it
// returns the error from the last attempt.
func (l *Lock) putLockRecordWithRetries(ctx context.Context, request objectstorage.PutObjectRequest, body []byte) (string, error) {
	var lastErr error
	for attempt := 1; attempt <= LockWriteRetriesOnFailures; attempt++ {
		// Each attempt drains the body reader, so build a fresh one every time.
		// Reusing a drained reader would send an empty body on retry.
		request.PutObjectBody = ioutil.NopCloser(bytes.NewReader(body))

		writeCtx, writeCancel := context.WithTimeout(ctx, ObjectStorageCallsWriteTimeout)
		response, err := l.backend.client.PutObject(writeCtx, request)
		l.backend.logRequest("putHA", response.RawResponse, response.OpcClientRequestId, response.OpcRequestId, err)
		if response.RawResponse != nil {
			response.RawResponse.Body.Close()
		}
		// Release this attempt's timer now instead of deferring to function exit.
		writeCancel()

		if err == nil {
			if response.ETag == nil {
				return "", errors.New("put lock record: response has no ETag")
			}
			return *response.ETag, nil
		}
		lastErr = err

		if response.RawResponse == nil {
			// No HTTP response was received (e.g. a transport error or timeout):
			// do not retry.
			metrics.IncrCounter(metricPutFailed, 1)
			l.backend.logger.Error("PUT", "err", err)
			return "", err
		}
		if response.RawResponse.StatusCode/100 != 5 {
			l.backend.logger.Error("PUT", "err", err)
			return "", err
		}

		metrics.IncrCounter(metricPutFailed, 1)
		l.backend.logger.Warn("PUT. Retrying..", "err", err)
		if attempt == LockWriteRetriesOnFailures {
			return "", lastErr
		}
		// Linear backoff, abandoned early if the lock is being stopped.
		select {
		case <-ctx.Done():
			return "", lastErr
		case <-time.After(time.Duration(100*attempt) * time.Millisecond):
		}
	}
	return "", lastErr
}
